MTTKRP speedup (1.5x - 6x range) with einsum #454
Replace the Y, Ur loop
```
V = np.zeros((szn, R))
for r in range(R):
V[:, [r]] = Y[:, :, r].T @ Ur[:, :, r]
```
with
```
V = np.einsum('ijk, ik -> jk', Y, Ur)
```
for numpy vectorization with einsum.
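As a quick sanity check of the replacement, the two forms can be compared on random data. This is only a sketch: the sizes are arbitrary, `Ur` is assumed to be 2-D with shape `(szl, R)` (the layout the einsum subscripts `ik` imply), and the names `loop_version` and `einsum_version` are hypothetical helpers, not pyttb functions.

```python
import numpy as np

rng = np.random.default_rng(0)
szl, szn, R = 6, 5, 4  # arbitrary illustrative sizes
Y = rng.standard_normal((szl, szn, R))
Ur = rng.standard_normal((szl, R))  # assumed 2-D layout, matching 'ik'


def loop_version(Y, Ur):
    """Column-by-column product, one matrix-vector multiply per rank index r."""
    V = np.zeros((Y.shape[1], Y.shape[2]))
    for r in range(Y.shape[2]):
        V[:, [r]] = Y[:, :, r].T @ Ur[:, [r]]
    return V


def einsum_version(Y, Ur):
    """Same contraction over the first axis, expressed as a single einsum."""
    return np.einsum("ijk,ik->jk", Y, Ur)


V_loop = loop_version(Y, Ur)
V_einsum = einsum_version(Y, Ur)
print(np.allclose(V_loop, V_einsum))
```

Both versions compute `V[j, k] = sum_i Y[i, j, k] * Ur[i, k]`; the einsum form just removes the Python-level loop over `k`.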
Some profiling results:
shape [20, 30, 40, 50]
---------------------
mode-1 mttkrp
rank-10
new mttkrp correct True
old 0.00010779168870713975
new 6.477038065592448e-05
speedup 1.6642126789366052
-----------------------------------------------------------------------
rank-50
new mttkrp correct True
old 0.0006463527679443359
new 0.000177409913804796
speedup 3.6432731073615052
-----------------------------------------------------------------------
rank-100
new mttkrp correct True
old 0.0006452401479085287
new 0.00010201666090223525
speedup 6.324850688132952
-----------------------------------------------------------------------
rank-200
new mttkrp correct True
old 0.0016609827677408855
new 0.0002986590067545573
speedup 5.561468866418307
-----------------------------------------------------------------------
mode-1 mttkrp
rank-10
new mttkrp correct True
old 0.0004661348130967882
new 0.00024816724989149306
speedup 1.8783091374893253
-----------------------------------------------------------------------
rank-50
new mttkrp correct True
old 0.002485460705227322
new 0.001473824183146159
speedup 1.6864024445043586
-----------------------------------------------------------------------
rank-100
new mttkrp correct True
old 0.006084998448689778
new 0.0025424957275390625
speedup 2.3933170792698175
-----------------------------------------------------------------------
rank-200
new mttkrp correct True
old 0.04606061511569553
new 0.02842871348063151
speedup 1.6202145463626638
-----------------------------------------------------------------------
Do you have any results for small rank (e.g., 2)? For really big dimensions? What is the impact on memory?

I think we should probably test the performance on a variety of systems and cases to ensure the improvements are consistent. Small-rank cases are definitely important as well.

@kennykos Yes, thanks for this! Can you please attach your test script here so that other developers can assess your contribution as well?
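The author's test script is not attached in this thread. Purely as an illustration of what such a comparison could look like, here is a minimal sketch using the standard-library `timeit` module. The array sizes, the 2-D layout assumed for `Ur`, and the helper names `loop_version`, `einsum_version`, and `best_time` are all assumptions, not taken from the PR.

```python
import timeit

import numpy as np


def loop_version(Y, Ur):
    """Old form: one matrix-vector product per rank index."""
    V = np.zeros((Y.shape[1], Y.shape[2]))
    for r in range(Y.shape[2]):
        V[:, [r]] = Y[:, :, r].T @ Ur[:, [r]]
    return V


def einsum_version(Y, Ur):
    """New form: a single einsum contraction over the first axis."""
    return np.einsum("ijk,ik->jk", Y, Ur)


def best_time(fn, *args, repeat=5, number=20):
    """Best per-call time in seconds over `repeat` batches of `number` calls."""
    batches = timeit.repeat(lambda: fn(*args), repeat=repeat, number=number)
    return min(batches) / number


rng = np.random.default_rng(0)
timings = {}
for R in (2, 10, 50):  # includes the small-rank case asked about above
    Y = rng.standard_normal((20, 30, R))
    Ur = rng.standard_normal((20, R))
    # Check agreement before timing so a speedup is never reported for wrong output.
    assert np.allclose(loop_version(Y, Ur), einsum_version(Y, Ur))
    old = best_time(loop_version, Y, Ur)
    new = best_time(einsum_version, Y, Ur)
    timings[R] = (old, new)
    print(f"rank-{R}: old {old:.3e} new {new:.3e} speedup {old / new:.2f}")
```

Taking the minimum over several batches reduces noise from other processes; absolute numbers will still vary by machine and BLAS build.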
Co-authored-by: Nick <24689722+ntjohnson1@users.noreply.github.com>
@jeremy-myers I've added a rank-2 case to my script. The einsum function is still slightly faster (which makes sense, since the serial loop in the old version is smaller). I'm not sure how to measure the intermediate memory usage of numpy's einsum, but reading around Stack Exchange hints that it avoids creating large intermediate arrays 🤷.

@kennykos I was hoping you’d have an idea! There are some tools like
On my local machine, for the shapes requested, the high-water memory mark differs (on average) by 1-2%: sometimes better, sometimes worse. I opened a PR here https://github.com/kennykos/pyttb/pull/1/files into @kennykos's branch that tries to do a quick benchmarking example around time and memory usage. If we are going to ask contributors to do this analysis, then I think we should have an example of our expectations. I think the notebook is a reasonable minimum bar, but benchmarking has lots of depth to it if we want to really dig in. I'll post my notes below for the performance I saw (on a MacBook Pro M4).

If we want a wider range of tests, I propose @kennykos merges my PR into the branch and updates the notebook (remember to run pre-commit to clear outputs so as not to anger CI). Otherwise, if this demonstration is sufficient, we can merge this PR and I will open one with just the mttkrp benchmarking comparison as a reference for future contributors. If any of this is controversial, whoever objects should create an issue or discussion around benchmarking.
Timing results shape [20, 30, 40, 50]
---------------------
mode-1 mttkrp
rank-2
results equal: True
old 3.467665778266059e-05
new 3.065003289116753e-05
speedup 1.1313742437337946
-----------------------------------------------------------------------
rank-10
results equal: True
old 6.0452355278862844e-05
new 3.512700398763021e-05
speedup 1.7209653092006032
-----------------------------------------------------------------------
rank-50
results equal: True
old 0.00021190113491482206
new 7.843971252441406e-05
speedup 2.7014522120905102
-----------------------------------------------------------------------
rank-100
results equal: True
old 0.0003298653496636285
new 9.889072842068142e-05
speedup 3.335654969193678
-----------------------------------------------------------------------
rank-200
results equal: True
old 0.0006497701009114584
new 0.0001304414537217882
speedup 4.9813160032493915
-----------------------------------------------------------------------
mode-1 mttkrp
rank-2
results equal: True
old 0.00026498900519476994
new 0.0002613597446017795
speedup 1.0138860733833366
-----------------------------------------------------------------------
rank-10
results equal: True
old 0.0005572372012668186
new 0.00037823783026801213
speedup 1.4732455525984032
-----------------------------------------------------------------------
rank-50
results equal: True
old 0.0014047622680664062
new 0.0005112489064534506
speedup 2.747707135084719
-----------------------------------------------------------------------
rank-100
results equal: True
old 0.0024165577358669704
new 0.0007385147942437066
speedup 3.2721859530812827
-----------------------------------------------------------------------
rank-200
results equal: True
old 0.0071366628011067705
new 0.0029048919677734375
speedup 2.456773911140293
-----------------------------------------------------------------------
Memory results shape [20, 30, 40, 50]
---------------------
mode-1 mttkrp
rank-2
results equal: True
old 82.421875
new 83.10069444444444
ratio 0.991831362553796
-----------------------------------------------------------------------
rank-10
results equal: True
old 83.71875
new 83.75347222222223
ratio 0.999585423489905
-----------------------------------------------------------------------
rank-50
results equal: True
old 85.19965277777777
new 85.53125
ratio 0.9961230869159258
-----------------------------------------------------------------------
rank-100
results equal: True
old 88.625
new 88.8125
ratio 0.9978888106966924
-----------------------------------------------------------------------
rank-200
results equal: True
old 96.24652777777777
new 99.08854166666667
ratio 0.9713184406482697
-----------------------------------------------------------------------
mode-1 mttkrp
rank-2
results equal: True
old 99.140625
new 99.140625
ratio 1.0
-----------------------------------------------------------------------
rank-10
results equal: True
old 113.890625
new 115.765625
ratio 0.983803482251316
-----------------------------------------------------------------------
rank-50
results equal: True
old 125.5
new 125.5
ratio 1.0
-----------------------------------------------------------------------
rank-100
results equal: True
old 143.84375
new 143.88194444444446
ratio 0.999734543172933
-----------------------------------------------------------------------
rank-200
results equal: True
old 180.53298611111111
new 180.546875
ratio 0.9999230732246743
-----------------------------------------------------------------------
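How the memory figures above were collected is not shown in this thread. One possible way to get a rough peak-allocation number for each variant is the standard-library `tracemalloc` module, sketched below. Note the assumptions: `tracemalloc` only sees allocations reported through Python's tracing hooks (recent NumPy versions report array buffers to it), so it is not the same quantity as a process-level high-water mark; the helper name `peak_bytes` and the array sizes are hypothetical.

```python
import tracemalloc

import numpy as np


def loop_version(Y, Ur):
    """Old form: one matrix-vector product per rank index."""
    V = np.zeros((Y.shape[1], Y.shape[2]))
    for r in range(Y.shape[2]):
        V[:, [r]] = Y[:, :, r].T @ Ur[:, [r]]
    return V


def einsum_version(Y, Ur):
    """New form: a single einsum contraction over the first axis."""
    return np.einsum("ijk,ik->jk", Y, Ur)


def peak_bytes(fn, *args):
    """Run fn(*args) once and return (result, peak traced bytes during the call)."""
    tracemalloc.start()
    try:
        result = fn(*args)
        _, peak = tracemalloc.get_traced_memory()
    finally:
        tracemalloc.stop()
    return result, peak


rng = np.random.default_rng(0)
Y = rng.standard_normal((20, 30, 50))  # inputs are allocated before tracing starts
Ur = rng.standard_normal((20, 50))

V_old, peak_old = peak_bytes(loop_version, Y, Ur)
V_new, peak_new = peak_bytes(einsum_version, Y, Ur)
print(f"old peak {peak_old} B, new peak {peak_new} B")
```

Because the inputs are created before `tracemalloc.start()`, the reported peaks cover only the output array and any temporaries each variant creates.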
I would also want to see order-3 and order-5 results, larger sizes, and more of a mix of sizes. Ideally, we would also test on something that uses GPUs. I would prefer to adopt what einsum is doing under the hood if possible. I actually haven't looked too closely at what mttkrp is doing in pyttb, so it's possible there is room for simple improvements. I will take a look.

Can you enumerate the range of larger and mixed sizes? If we can specify it, then there's a clearer bar for performance considerations.
My understanding is that einsum simply allows us to chain operations on a pre-compiled code path. So while pyttb is still pure Python, I don't believe that direction helps (I could be incorrect that it is just an order-of-operations thing).
Unless I missed something, pyttb won't leverage GPUs since numpy doesn't dispatch to them.

I see that we already have einsum along with numpy and that this is just an alternate way of specifying batched matrix-matrix multiplies. I withdraw my objections.

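To make the "batched multiply" reading concrete, the same contraction can be written with `np.matmul`, treating the rank index as the batch dimension. This is a sketch under the assumption that `Ur` is 2-D with shape `(i, k)`; in that case each batch is a matrix times a single column. The sizes are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
Y = rng.standard_normal((6, 5, 4))  # axes (i, j, k)
Ur = rng.standard_normal((6, 4))    # axes (i, k), assumed 2-D layout

# einsum form used in the PR: V[j, k] = sum_i Y[i, j, k] * Ur[i, k]
V_einsum = np.einsum("ijk,ik->jk", Y, Ur)

# Equivalent batched product: move k to the front so matmul batches over it.
Y_batched = Y.transpose(2, 1, 0)             # (k, j, i)
Ur_batched = Ur.T[:, :, np.newaxis]          # (k, i, 1)
V_batched = np.matmul(Y_batched, Ur_batched)  # (k, j, 1)
V_matmul = V_batched[:, :, 0].T              # back to (j, k)

print(np.allclose(V_einsum, V_matmul))
```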
For other sizes (though I don't think it will matter here), I would suggest 1-2 tensors where there is some large difference in the sizes, e.g., 3 x 1000 x 100. And then some tensors that are basically as large as 1/4 to 1/2 of available memory.
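The size suggestions above could be turned into concrete shapes roughly as follows. This is a sketch only: the 8 GiB memory budget is an assumed placeholder (query the real value on the target machine), float64 storage is assumed, and `cubic_side_for_budget` is a hypothetical helper.

```python
def cubic_side_for_budget(budget_bytes, fraction, order=3, itemsize=8):
    """Largest side n such that an n**order tensor fits in fraction * budget_bytes."""
    target_elems = int(budget_bytes * fraction) // itemsize
    # Start from a floating-point root, then correct for rounding in both directions.
    side = max(1, round(target_elems ** (1.0 / order)))
    while side > 1 and side**order > target_elems:
        side -= 1
    while (side + 1) ** order <= target_elems:
        side += 1
    return side


# Skewed shape taken directly from the suggestion above.
skewed_shape = (3, 1000, 100)

# Assumed memory budget; replace with the machine's actual available memory.
budget = 8 * 1024**3

quarter_side = cubic_side_for_budget(budget, 0.25)
half_side = cubic_side_for_budget(budget, 0.5)
benchmark_shapes = [skewed_shape, (quarter_side,) * 3, (half_side,) * 3]
print(benchmark_shapes)
```

Keep in mind that a tensor filling 1/4 to 1/2 of memory leaves limited room for intermediates, so the old and new variants may behave differently there than at small sizes.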
📚 Documentation preview 📚: https://pyttb--454.org.readthedocs.build/en/454/